FedSAM
What is Federated Learning?
In classical Machine Learning, there is a model and there is data. Let us take the model to be a neural network or a classical linear regression. We can train this model on our available data to perform a useful task, such as detecting objects, transcribing an audio file into text, or playing games such as Go or chess.
However, the data we want to train on is not always owned by us or openly available. Because of data privacy, users' data is not shared, and so it remains largely untapped, benefiting neither the users nor the companies. Mobile phones, for instance, are among the richest sources of such data: location traces, keyboard input, and speed and position information from vehicles are all examples.
Normally, this data would be collected at a central point and training would proceed as described above. Federated Learning takes the opposite approach: training happens on the local devices, each using its own data, and only the resulting model updates, rather than the data itself, are sent back to update the global model. In this way, data privacy is preserved while models still benefit from otherwise inaccessible data.
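Conceptually, one federated round looks like the sketch below; `client.train_locally` is a hypothetical stand-in for the on-device training step, and the weighted average is the core idea behind the FedAvg strategy used later:

```python
def federated_round(global_weights, clients):
    """One illustrative round: each client trains locally on its own data;
    only the resulting weights (never the raw data) are sent back."""
    updates, sizes = [], []
    for client in clients:
        local_weights, num_samples = client.train_locally(global_weights)
        updates.append(local_weights)
        sizes.append(num_samples)
    total = sum(sizes)
    # Weighted average of the client updates (the idea behind FedAvg)
    return [
        sum(weights[i] * (n / total) for weights, n in zip(updates, sizes))
        for i in range(len(global_weights))
    ]
```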
For Federated Learning we use the Flower Framework, which is compatible with most Deep Learning and Machine Learning frameworks and is fast and efficient. Since Flower, currently one of the most widely used federated learning frameworks, also supports the HuggingFace Transformers library, the integration is straightforward.
Flower Client
We begin with the implementation of the Flower client. Within the client class, we first instantiate the Train_Federated model prepared for Federated Learning, as discussed in sections 1.1 HF Transformers SAM's Fine Tune and 1.4 Fine Tune of SAM.
To federate our example across multiple clients, we first need to write our Flower client class (inheriting from `flwr.client.NumPyClient`). This is straightforward, as our model is a standard PyTorch model.
```python
from collections import OrderedDict

import flwr as fl
import torch


class SAMClient(fl.client.NumPyClient):
    """Flower client implementing SAM.

    Args:
        dataset_root (str): Root directory of the dataset
        image_subfolder (str): Name of the image subfolder
        annotation_subfolder (str): Name of the annotation subfolder
        batch_size (int): Batch size for training
        num_epochs (int): Number of epochs for training
    """
```
Since the SAMClient methods delegate to the training class, we only give a general overview of each method here.
`get_parameters`
```python
def get_parameters(self, **kwargs):
    # Returns the initial parameters (before training)
    return self.train_model.get_model_parameters(self.train_model.model)
```
- The `get_parameters` function retrieves the client's model parameters for the server, using the `get_model_parameters` function from the Training class.
```python
def get_model_parameters(self, model):
    """Get model parameters as a list of NumPy ndarrays.

    Args:
        model (nn.Module): Model to get the parameters from.

    Returns:
        list: List of NumPy ndarrays representing the parameters.
    """
    return [val.cpu().numpy() for _, val in model.state_dict().items()]
```
```python
def set_parameters(self, parameters):
    # Sets the model parameters received from the server
    self.train_model.set_model_parameters(self.train_model.model, parameters)
```
- Conversely, the `set_parameters` function applies the parameters received from the server to the local model, utilizing the `set_model_parameters` function from the Training class in a similar manner.
```python
def set_model_parameters(self, model, parameters):
    """Set model parameters.

    Args:
        model (nn.Module): Model to set the parameters for.
        parameters (list): List of NumPy ndarrays representing the parameters.
    """
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict)
```
Flower Server
In this section, we will create a Server for the Client we have constructed. It is necessary first to determine the strategy to be employed. Flower comes with popular ready-to-use federated learning strategies. Strategies define rules that regulate how local updates will be combined and managed during the iterative model training process. They may also include mechanisms to address communication constraints, mitigate the effects of outlier participants, and improve model convergence rates.
Research continues to extend and enhance existing strategies to address new challenges and scenarios in federated learning, such as hybrid federated dual coordinate ascent. Additionally, Flower makes it possible to implement custom federated learning strategies. Among the general-purpose server-side strategies, however, FedAvg is the most widely used, and when combined with client-side techniques suited to non-IID image data (for example, MAML), it can be highly effective for such tasks. FedAvg will therefore be our strategy on the server side. The snippets below are merely examples, yet federated training can be initiated or extended in this manner.
```python
# Strategy selection
strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,
    min_available_clients=2
)
```
Here, as mentioned in the previous paragraph, we define our strategy as FedAvg, instantiating the corresponding class from the Flower Framework, and specify that training requires at least two clients.
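For completeness, a custom strategy can be sketched by subclassing FedAvg and overriding one of its hooks; the logging override below is purely illustrative and not required for our setup (in older Flower versions the first argument is named `rnd` rather than `server_round`):

```python
class LoggingFedAvg(fl.server.strategy.FedAvg):
    # Identical to FedAvg, but logs a summary after each aggregation round
    def aggregate_fit(self, server_round, results, failures):
        aggregated = super().aggregate_fit(server_round, results, failures)
        print(f"Round {server_round}: aggregated updates from "
              f"{len(results)} clients ({len(failures)} failures)")
        return aggregated
```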
```python
# Server configuration
server_config = fl.server.ServerConfig(
    num_rounds=10
)
```
Here, we define the number of rounds the server will run. Below, we initialize our server:
```python
# Server initialization
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=server_config,
    strategy=strategy
)
```
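On each participating device, the corresponding client then connects to this address. A minimal sketch using Flower's `start_numpy_client` helper, assuming the SAMClient constructor arguments shown earlier (the paths below are hypothetical):

```python
fl.client.start_numpy_client(
    server_address="0.0.0.0:8080",
    client=SAMClient(
        dataset_root="data",                 # hypothetical example paths
        image_subfolder="images",
        annotation_subfolder="annotations",
        batch_size=2,
        num_epochs=1,
    ),
)
```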
Flower Segment Anything Model Training
In the client section, we described our training for the local setting in section 1.4 Fine Tune of SAM. Here, certain functions exist in the Client that are not present in local training, as follows.
```python
def set_model_parameters(self, model, parameters):
    """Set model parameters.

    Args:
        model (nn.Module): Model to set the parameters for.
        parameters (list): List of NumPy ndarrays representing the parameters.
    """
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict)
```
- `set_model_parameters` applies the model parameters sent back from the server to the clients, after they have been retrieved from the clients using `get_model_parameters`.
```python
def get_model_parameters(self, model):
    """Get model parameters as a list of NumPy ndarrays.

    Args:
        model (nn.Module): Model to get the parameters from.

    Returns:
        list: List of NumPy ndarrays representing the parameters.
    """
    return [val.cpu().numpy() for _, val in model.state_dict().items()]
```
- `get_model_parameters` collects the model parameters from the clients for transmission to the server. The parameters sent to the server are aggregated using the strategies mentioned in the Flower Server section and returned to the clients through `set_model_parameters`.
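As a quick sanity check, the two functions form a round trip on the local side; `trainer` below is a hypothetical instance of the Training class:

```python
# Extracting and then restoring parameters leaves the model unchanged
params = trainer.get_model_parameters(trainer.model)
trainer.set_model_parameters(trainer.model, params)
```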
Returning to the Client side, the training of the model occurs as follows.
```python
def fit(self, parameters, config):
    # Trains the model starting from the parameters received from the server
    updated_parameters = self.train_model.train(initial_parameters=parameters)
    return updated_parameters, len(self.train_model.train_dataloader().dataset), {}
```
The global parameters are received from the server, the model is trained starting from them, and the updated parameters are returned along with the size of the local training dataset.
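Flower's NumPyClient can also optionally implement an `evaluate` method so the server can track validation performance each round; the sketch below assumes hypothetical `evaluate` and `val_dataloader` helpers on the training class:

```python
def evaluate(self, parameters, config):
    # Evaluates the received global model on the local validation split
    self.set_parameters(parameters)
    loss = self.train_model.evaluate()  # hypothetical helper
    return float(loss), len(self.train_model.val_dataloader().dataset), {}
```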
Links:
- https://research.google/pubs/federated-learning-strategies-for-improving-communication-efficiency/
- https://www.mdpi.com/2079-9292/12/10/2287
- https://flower.dev/docs/framework/how-to-use-strategies.html
- https://flower.dev/docs/framework/how-to-implement-strategies.html
- https://arxiv.org/abs/1703.03400